# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for open_spiel.python.algorithms.exploitability."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

from open_spiel.python import policy
from open_spiel.python.algorithms import exploitability
from open_spiel.python.games import data
import pyspiel


class ExploitabilityTest(unittest.TestCase):
  """Tests for exploitability, NashConv and best responses on poker games.

  Expected values for the uniform random policy are taken from
  https://link.springer.com/chapter/10.1007/978-3-319-75931-9_5
  (found on Google Books; it also covers the multiplayer games). The value
  that paper reports for Leduc poker against the uniform policy is wrong,
  though: several independent code bases agree that 4.7472 is correct.

  The value for the "AlwaysFold" policy is trivial: if you always fold, you
  win 0 chips, but if you switch to "AlwaysBet", you win 1 chip every time
  when playing against a player who always folds.
  """

  def test_exploitability_on_kuhn_poker_uniform_random(self):
    """Exploitability of uniform random play in Kuhn poker is NashConv / 2."""
    game = pyspiel.load_game("kuhn_poker")
    test_policy = policy.UniformRandomPolicy(game)
    # NashConv of the uniform random policy, from the paper cited in the
    # class docstring.
    expected_nash_conv = 11 / 12
    self.assertAlmostEqual(
        exploitability.exploitability(game, test_policy),
        expected_nash_conv / 2)

  def test_kuhn_poker_uniform_random_best_response_pid0(self):
    """Player 0's best response to uniform random play in Kuhn poker."""
    game = pyspiel.load_game("kuhn_poker")
    test_policy = policy.UniformRandomPolicy(game)
    results = exploitability.best_response(game, test_policy, player_id=0)
    # NOTE(review): keys look like the private card followed by the betting
    # history ("p" = pass, "b" = bet) - confirm against the game's
    # information state strings.
    self.assertEqual(
        results["best_response_action"],
        {
            "0": 1,  # Bet in case opponent folds when winning
            "1": 1,  # Bet in case opponent folds when winning
            "2": 0,  # Both equally good (we return the lowest action)
            # Some of these will never happen under the best-response policy,
            # but we have computed best-response actions anyway.
            "0pb": 0,  # Fold - we're losing
            "1pb": 1,  # Call - we're 50-50
            "2pb": 1,  # Call - we've won
        })
    self.assertGreater(results["nash_conv"], 0.1)

  def test_kuhn_poker_uniform_random_best_response_pid1(self):
    """Player 1's best response to uniform random play in Kuhn poker."""
    game = pyspiel.load_game("kuhn_poker")
    test_policy = policy.UniformRandomPolicy(game)
    results = exploitability.best_response(game, test_policy, player_id=1)
    self.assertEqual(
        results["best_response_action"],
        {
            # Bet is always best
            "0p": 1,
            "1p": 1,
            "2p": 1,
            # Call unless we know we're beaten
            "0b": 0,
            "1b": 1,
            "2b": 1,
        })
    self.assertGreater(results["nash_conv"], 0.1)

  def test_kuhn_poker_uniform_random(self):
    """NashConv of the uniform random policy in Kuhn poker is 11/12."""
    game = pyspiel.load_game("kuhn_poker")
    test_policy = policy.UniformRandomPolicy(game)
    self.assertAlmostEqual(exploitability.nash_conv(game, test_policy), 11 / 12)

  def test_kuhn_poker_always_fold(self):
    """NashConv of the first-action ("AlwaysFold") policy in Kuhn poker is 2."""
    game = pyspiel.load_game("kuhn_poker")
    test_policy = policy.FirstActionPolicy(game)
    self.assertAlmostEqual(exploitability.nash_conv(game, test_policy), 2)

  def test_kuhn_poker_optimal(self):
    """The Kuhn equilibrium policy built with alpha=0.2 has zero NashConv."""
    game = pyspiel.load_game("kuhn_poker")
    test_policy = data.kuhn_nash_equilibrium(alpha=0.2)
    self.assertAlmostEqual(exploitability.nash_conv(game, test_policy), 0)

  def test_leduc_poker_uniform_random(self):
    """NashConv of the uniform random policy in Leduc poker."""
    # NashConv taken from independent implementations
    game = pyspiel.load_game("leduc_poker")
    test_policy = policy.UniformRandomPolicy(game)
    self.assertAlmostEqual(
        exploitability.nash_conv(game, test_policy), 4.747222222222222)

  def test_leduc_poker_always_fold(self):
    """NashConv of the first-action ("AlwaysFold") policy in Leduc poker is 2."""
    game = pyspiel.load_game("leduc_poker")
    test_policy = policy.FirstActionPolicy(game)
    self.assertAlmostEqual(exploitability.nash_conv(game, test_policy), 2)

  def test_2p_nash_conv(self):
    """NashConv of several fixed policies on two-player Kuhn and Leduc poker.

    Every case carries a label that is included in the failure message, so a
    failing assertion identifies the game/policy pair that produced it.
    """
    kuhn_poker = pyspiel.load_game("kuhn_poker")
    leduc_poker = pyspiel.load_game("leduc_poker")
    # Each entry is (label, game, policy, expected NashConv).
    # Note: the first-action policy is "AlwaysFold".
    test_parameters = [
        ("kuhn_poker/uniform_random", kuhn_poker,
         policy.UniformRandomPolicy(kuhn_poker), 0.9166666666666666),
        ("kuhn_poker/first_action", kuhn_poker,
         policy.FirstActionPolicy(kuhn_poker), 2.),
        ("kuhn_poker/nash_equilibrium", kuhn_poker,
         data.kuhn_nash_equilibrium(alpha=0.2), 0.),
        ("leduc_poker/first_action", leduc_poker,
         policy.FirstActionPolicy(leduc_poker), 2.),
        ("leduc_poker/uniform_random", leduc_poker,
         policy.UniformRandomPolicy(leduc_poker), 4.747222222222222),
    ]
    for label, game, test_policy, expected_value in test_parameters:
      actual_value = exploitability.nash_conv(game, test_policy)
      # Spell out both values in the message: an explicit `msg` replaces the
      # default "x != y" text whenever `longMessage` is disabled.
      self.assertAlmostEqual(
          actual_value,
          expected_value,
          msg="NashConv for {}: got {}, expected {}".format(
              label, actual_value, expected_value))

  # TODO: add a test with values from:
  # http://poker.cs.ualberta.ca/publications/AAMAS13-3pkuhn.pdf

  def test_kuhn_poker_3p_uniform_random_nash_conv(self):
    """NashConv of uniform random play in 3-player Kuhn poker exceeds 2."""
    game = pyspiel.load_game("kuhn_poker",
                             {"players": pyspiel.GameParameter(3)})
    test_policy = policy.UniformRandomPolicy(game)
    self.assertGreater(exploitability.nash_conv(game, test_policy), 2)

  def test_kuhn_poker_4p_uniform_random_nash_conv(self):
    """NashConv of uniform random play in 4-player Kuhn poker exceeds 3."""
    game = pyspiel.load_game("kuhn_poker",
                             {"players": pyspiel.GameParameter(4)})
    test_policy = policy.UniformRandomPolicy(game)
    self.assertGreater(exploitability.nash_conv(game, test_policy), 3)


if __name__ == "__main__":
  # Run the tests in this module when the file is executed as a script.
  unittest.main()
